{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2f3d58ac-fee3-41fc-969a-c7d92ac8002c",
   "metadata": {},
   "source": [
    "# 使用MindSpore完成图像分类\n",
    "\n",
    "本节主要介绍使用MindSpore进行图像分类的简单流程。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c670d88-1d23-41d4-8151-fd267d342321",
   "metadata": {},
   "source": [
    "本文使用的数据集为MNIST手写数字识别数据集。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cccea3e9-fb6b-4938-b5f0-a6f73f6c0920",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "9913344B [00:01, 6020341.43B/s]                              \n",
      "29696B [00:00, 44835871.70B/s]          \n",
      "1649664B [00:00, 10285702.86B/s]                              \n",
      "5120B [00:00, 13142494.79B/s]          \n"
     ]
    }
   ],
   "source": [
    "from mindvision.dataset import Mnist\n",
    "\n",
    "# Download MNIST into ./mnist (if missing) and configure the train/test pipelines.\n",
    "# Both splits use batches of 32 images resized to 32x32.\n",
    "mnist_common = dict(path=\"./mnist\", batch_size=32, resize=32, download=True)\n",
    "\n",
    "# Training split: a single repeat per epoch, shuffled.\n",
    "download_train = Mnist(split=\"train\", repeat_num=1, shuffle=True, **mnist_common)\n",
    "# Test split: repeat/shuffle left at the library defaults.\n",
    "download_eval = Mnist(split=\"test\", **mnist_common)\n",
    "\n",
    "# run() returns the dataset objects consumed later by model.train / model.eval.\n",
    "dataset_train, dataset_eval = download_train.run(), download_eval.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8429a09b-ae1f-470e-b638-76bc7d810db9",
   "metadata": {},
   "source": [
    "所使用的网络仍为LeNet5网络结构，其具体设置已在“多层神经网络”一节中介绍。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e6b79a5a-8393-4610-9ea7-cc2382f51aa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.nn as nn\n",
    "\n",
    "class LeNet5(nn.Cell):\n",
    "    \"\"\"\n",
    "    LeNet-5 network structure.\n",
    "\n",
    "    Args:\n",
    "        num_class: number of output classes, i.e. the width of the last Dense layer.\n",
    "        num_channel: number of channels of the input images.\n",
    "\n",
    "    fc1 expects 16 * 5 * 5 features, which matches 32x32 inputs:\n",
    "    32 --conv 5x5--> 28 --pool 2--> 14 --conv 5x5--> 10 --pool 2--> 5.\n",
    "    \"\"\"\n",
    "    def __init__(self, num_class=10, num_channel=1):\n",
    "        super().__init__()\n",
    "        # Feature extractor: two 5x5 'valid' convolutions (num_channel -> 6 -> 16 channels).\n",
    "        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')\n",
    "        # Classifier head: 16*5*5 -> 120 -> 84 -> num_class.\n",
    "        self.fc1 = nn.Dense(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Dense(120, 84)\n",
    "        self.fc3 = nn.Dense(84, num_class)\n",
    "        # Parameter-free building blocks, reused by every stage in construct().\n",
    "        self.relu = nn.ReLU()\n",
    "        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.flatten = nn.Flatten()\n",
    "\n",
    "    def construct(self, x):\n",
    "        # Stages 1 and 2: convolution -> ReLU -> 2x2 max pooling.\n",
    "        x = self.max_pool2d(self.relu(self.conv1(x)))\n",
    "        x = self.max_pool2d(self.relu(self.conv2(x)))\n",
    "        # Flatten the feature maps into one vector per sample, then apply the Dense layers.\n",
    "        x = self.flatten(x)\n",
    "        x = self.relu(self.fc1(x))\n",
    "        x = self.relu(self.fc2(x))\n",
    "        # Raw logits (no softmax); SoftmaxCrossEntropyWithLogits applies softmax itself.\n",
    "        return self.fc3(x)\n",
    "\n",
    "network = LeNet5(num_class=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2680794c-60a2-49d9-8fee-c829388c8751",
   "metadata": {},
   "source": [
    "定义损失函数和优化器，并配置MindSpore的模型检查点（checkpoint）回调，以便在训练过程中保存模型参数；随后构建`Model`在训练集上进行训练，最后在测试集上评估模型精度。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "be493eaa-d33a-4a90-9094-2d265d3f69a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.nn as nn\n",
    "\n",
    "# Loss: softmax cross-entropy on raw logits; sparse=True takes integer class labels, averaged over the batch.\n",
    "net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')\n",
    "\n",
    "# Optimizer: momentum SGD over every trainable parameter of the network.\n",
    "LEARNING_RATE = 0.01\n",
    "MOMENTUM = 0.9\n",
    "net_opt = nn.Momentum(network.trainable_params(), learning_rate=LEARNING_RATE, momentum=MOMENTUM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b327a059-e06c-489b-8db7-d84a45440b49",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig\n",
    "\n",
    "# 60000 training images / batch size 32 = 1875 steps, so one checkpoint is written per epoch.\n",
    "STEPS_PER_EPOCH = 1875\n",
    "\n",
    "# Save every STEPS_PER_EPOCH steps and keep at most 10 checkpoint files.\n",
    "config_ck = CheckpointConfig(save_checkpoint_steps=STEPS_PER_EPOCH, keep_checkpoint_max=10)\n",
    "\n",
    "# Callback that writes lenet-<epoch>_<step>.ckpt files into ./lenet during training.\n",
    "ckpoint = ModelCheckpoint(prefix=\"lenet\", directory=\"./lenet\", config=config_ck)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fece4bd4-6387-4e79-8b80-96e150ff50a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:[  0/ 10], step:[ 1875/ 1875], loss:[2.312/2.312], time:4248.603 ms, lr:0.01000\n",
      "Epoch time: 5703.426 ms, per step time: 3.042 ms, avg loss: 2.312\n",
      "Epoch:[  1/ 10], step:[ 1875/ 1875], loss:[2.315/2.315], time:2658.369 ms, lr:0.01000\n",
      "Epoch time: 2661.980 ms, per step time: 1.420 ms, avg loss: 2.315\n",
      "Epoch:[  2/ 10], step:[ 1875/ 1875], loss:[0.208/0.208], time:2912.382 ms, lr:0.01000\n",
      "Epoch time: 2915.887 ms, per step time: 1.555 ms, avg loss: 0.208\n",
      "Epoch:[  3/ 10], step:[ 1875/ 1875], loss:[0.075/0.075], time:2994.525 ms, lr:0.01000\n",
      "Epoch time: 2996.868 ms, per step time: 1.598 ms, avg loss: 0.075\n",
      "Epoch:[  4/ 10], step:[ 1875/ 1875], loss:[0.005/0.005], time:2795.298 ms, lr:0.01000\n",
      "Epoch time: 2797.520 ms, per step time: 1.492 ms, avg loss: 0.005\n",
      "Epoch:[  5/ 10], step:[ 1875/ 1875], loss:[0.005/0.005], time:2919.796 ms, lr:0.01000\n",
      "Epoch time: 2925.434 ms, per step time: 1.560 ms, avg loss: 0.005\n",
      "Epoch:[  6/ 10], step:[ 1875/ 1875], loss:[0.015/0.015], time:2862.038 ms, lr:0.01000\n",
      "Epoch time: 2864.114 ms, per step time: 1.528 ms, avg loss: 0.015\n",
      "Epoch:[  7/ 10], step:[ 1875/ 1875], loss:[0.014/0.014], time:2810.667 ms, lr:0.01000\n",
      "Epoch time: 2812.954 ms, per step time: 1.500 ms, avg loss: 0.014\n",
      "Epoch:[  8/ 10], step:[ 1875/ 1875], loss:[0.005/0.005], time:2775.033 ms, lr:0.01000\n",
      "Epoch time: 2776.391 ms, per step time: 1.481 ms, avg loss: 0.005\n",
      "Epoch:[  9/ 10], step:[ 1875/ 1875], loss:[0.000/0.000], time:2733.139 ms, lr:0.01000\n",
      "Epoch time: 2735.367 ms, per step time: 1.459 ms, avg loss: 0.000\n"
     ]
    }
   ],
   "source": [
    "from mindvision.engine.callback import LossMonitor\n",
    "from mindspore.train import Model\n",
    "\n",
    "# Wrap network, loss and optimizer into a high-level Model; 'accuracy' is reported by model.eval.\n",
    "model = Model(network, loss_fn=net_loss, optimizer=net_opt, metrics={'accuracy'})\n",
    "\n",
    "# Train for EPOCHS epochs. ckpoint saves one checkpoint per epoch\n",
    "# (lenet-1_1875.ckpt ... lenet-10_1875.ckpt in a clean ./lenet), so the fully\n",
    "# trained weights are in the LAST file, not in lenet-1_1875.ckpt.\n",
    "# LossMonitor(0.01, 1875): learning rate shown in the log, and print the loss every 1875 steps (once per epoch).\n",
    "EPOCHS = 10\n",
    "model.train(EPOCHS, dataset_train, callbacks=[ckpoint, LossMonitor(0.01, 1875)])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "97eeb0ef-d433-4367-adae-d4ced8e509d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'accuracy': 0.9899839743589743}\n"
     ]
    }
   ],
   "source": [
    "# Evaluate on the MNIST test split; the result maps each metric name to its value.\n",
    "acc = model.eval(dataset_eval)\n",
    "print(acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "40484b97-4297-49a5-b0c3-b13b9cb65c53",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from mindspore import load_checkpoint, load_param_into_net\n",
    "\n",
    "# Reload the weights from the LAST checkpoint written during training.\n",
    "# lenet-1_1875.ckpt is only the epoch-1 snapshot (loss ~2.31, still untrained); loading it\n",
    "# overwrites the trained weights and makes every prediction the same class.\n",
    "# ModelCheckpoint records the path of the most recent file it wrote.\n",
    "ckpt_path = ckpoint.latest_ckpt_file_name\n",
    "param_dict = load_checkpoint(ckpt_path)\n",
    "# Copy the parameters into the network; the returned list names parameters that were NOT loaded.\n",
    "params_not_loaded = load_param_into_net(network, param_dict)\n",
    "assert not params_not_loaded, f\"network parameters not found in {ckpt_path}: {params_not_loaded}\"\n",
    "ckpt_path"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "263687f9-040f-4729-a378-eede00c8265c",
   "metadata": {},
   "source": [
    "对模型进行推理测试：取一批图片进行预测，并输出预测类别与实际类别。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "860c819e-2380-4441-9554-41ad78ce0fd6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD6CAYAAAC4RRw1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAjt0lEQVR4nO2debBV1bXuvyGCoKCAkAMCCgoasAURJaJRkYTYFFgxxCYGoyWJ9W6ayrUilUrF8tW7KdN5jXWtF0lhICnbgAkk8kRECHaAh0YiEASMCEgroCAKovP9cTbjfHtn73M25+xurvX9qnbxndXOtcfakzXHGmNMCyFACCFEfBxV7QYIIYRoGerAhRAiUtSBCyFEpKgDF0KISFEHLoQQkaIOXAghIqVVHbiZjTazNWa2zswmlqpRorrIrslFtk0W1tI4cDNrA+BNAKMAbALwGoAbQwirStc8UWlk1+Qi2yaPo1ux7zAA60IIbwGAmT0BYAyAgjeDmSlrqEYIIViBVbJrxDRhV+AIbSu71hQ7Qwjdcxe2xoXSC8BG+ntTZlkWZjbBzOrNrL4V5xKVQ3ZNLs3aVnatWTbkW9iaJ/CiCCFMAjAJ0P/oSUJ2TSaya1y05gl8M4A+9HfvzDIRN7JrcpFtE0ZrOvDXAAwws35m1g7ADQBmlqZZoorIrslFtk0YLXahhBAOmdl/AJgNoA2AR0IIK0vWshrlqKOy/88bPny469GjR7s+cOCA6+eee8714sWLy9i61pNWu6YB2TZ5tMoHHkKYBWBWidoiagTZNbnItslCmZhCCBEpZY9CSQLsNundu3fWujvuuMP1iBEjXC9fvtz1okWLytc4IURq0RO4EEJEijpwIYSIFLlQiuDooxu/Jo40AYCRI0e6fvvtt13Pnz/f9fr168vWNiFEetETuBBCRIo6cCGEiBS5UArQtm1b1/3793d98803F9zu6aefdj19+nTXW7ZsKUcThRApR0/gQggRKerAhRAiUtSBCyFEpMgHXoDjjz/e9Ze//GXXw4YNy9rujTfecL106VLX27ZtK2PrRLUwa5zwpl27dq75XUibNm1c8330ySefuN63b5/rjz76KOscn376aWkaKxKPnsCFECJS1IELIUSkyIVSAB4Gn3DCCa55CA0Udpt89tlnZWydKDW5dmX7s6ukU6dOrvv0aZzc5qSTTnLduXNn16NGjXLN98dLL73keuHChVnn3r59u+uk3Ef8/RZyNxWzL2+fW5u/uX2B7O/z0KFDrg8ePJh3ea2jJ3AhhIgUdeBCCBEpcqEUoGPHjq4vueQS13v37s3a7i9/+Yvrd955p+ztEqWDh9ccLQIAQ4cOdX3TTTe5HjJkiOvPfe5zeffnoT0XQguhcZL322+/3fXDDz+cde7f//73rtesWdPMVcQBuyGvvvpq1+ecc47rQu6Uuro61wMHDnR98sknu2a3DJPrQtm8uXEO51deecX1Y489lnc5T41Yi+gJXAghIkUduBBCRIpcKESXLl1cX3jhha7PO+8817lDqnfffdd1bkKGqA2OOeYY17169XLNbpI777wzax+eOq9bt26ujz32WNc7d+50zQldHMVw3HHHuWZ3AUeqcAQLkB3pkhTYDfKjH/3IdY8ePVwXiiph9wjbknWuq6QQHTp0cN2zZ0/X/fr1cz1t2jTXf/vb31xv2rSpqHNUEj2BCyFEpDTbgZvZI2a23czeoGVdzWyOma3N/NulqWOI2kN2TS6ybXooxoUyBcD/APgDLZsIYG4I4T4zm5j5++7SN6+ydO3a1fX555/vmiMMOMkCyK5vwVEGETAFKbErw5EjEydOdH3qqadmbcfRH8uXL3e9YsUK1+w24fuCa5mwO+Q73/mO6yuvvNL16aefnnVudiu0kCmoMduye5G/N/7eOWGKk5vYTcnuKY4I4204miX3u2U3FrvJ+PfO7h7W9957L2qNZp/AQwgLAOzKWTwGwNSM
ngpgbGmbJcqN7JpcZNv00NKXmHUhhMPTzGwFUFdoQzObAGBCC88jKovsmlyKsq3sGhetjkIJIQQzK+g7CCFMAjAJAJrarhbgKJRzzz3XNddPeOGFF7L2yU3sSQpJsivDLq+1a9e6fuaZZ7K2K1TjZuPGjXmXcy0NhqNQrr32WtfsCuBhOpCd9FIOmrJtuezKETuTJk1yze7Js846yzVPSbho0SJun+v9+/e73rNnj2uOZsl1R3Hyz8iRI11/6Utfcj1o0CDXZ599dr7LqRlaGoWyzcx6AkDm3+3NbC/iQHZNLrJtAmlpBz4TwPiMHg9gRmmaI6qM7JpcZNsE0qwLxcweB3AZgG5mtgnAPQDuA/CUmd0OYAOAceVsZDnh4RYnV3DCBw+P//AHfrEPvPfee+VrXBlJul0ZjgrhZIwHH3zQ9VtvvZW1z44dO1wXU16UIx840uWyyy5zPXjwYNeceJIb2cSz9bSEWrQtuzteffVV13/9619d9+3b1zV/n1xjaMuWLTgScusTLV682DXXqeFkPU7qKabcbTVptgMPIdxYYNXIAstFBMiuyUW2TQ/KxBRCiEhJfS0UfuPPM6xw9AAPw/iNOJA9NBS1CbtAOHKktRNPsxuE62qMHTvW9be//W3Xn//8511zosqsWbOyjpuUErKFYHssWLDANU8efvHFF7tev3696zlz5rguV+2hYuuq1AJ6AhdCiEhRBy6EEJGSehfKKaec4voLX/iCa56Rh4e4TUUkcBQL11XhGg/scuE36pxgImoXLifLZWZ58uK77rrLNbtW2BXw0EMPuX7uueeyzrFrV24WfLLgxLh//vOfru+//37XX/3qV11zVA+7PFvrQmE3KZem5aizWneR6glcCCEiRR24EEJESipdKPyWmWsjcN0DdmlwqcrckrE8DLvuuutc33zzza779+/v+uWXX3b9s5/9zDVHHhSTOCIqB88IwzM1jRvXmAvDLhR2m7z55puuv/a1r7nmxKFCdVTSBv82WJcSTtzj0rIcgfb222+7XrZsWVnaUSr0BC6EEJGiDlwIISIllS4ULmHJZWO5VsWqVatcc/0TflsNAA8//LDryy+/3DW7Vj7++GPXPNTmIdwPf/hD1y+++GLefUXl4DoZ119/vetbb73VNbtTOGqJ3SPf+9738i6X26Q68Aw9AwYMcM19As/AxLoW0RO4EEJEijpwIYSIFHXgQggRKan0gXNYH/vEuLgRZ8dx2OFPf/rTrGNdcMEFrufNm+eai/Tw1F2c+cn+0QceeMD1z3/+c9czZ850zdNGpRmu1c6zu3M9afZhMxyi+f7777vOrTN92223ueYiVJxhe+DAAddc1/q3v/2tay5+Jr939fnmN7/p+swzz3TN98KSJUtcsw+c7ynOuuY5AXLDjMuNnsCFECJS1IELIUSkpNKFMnz4cNcXXXSRax4Kbd261fU999zj+otf/GLWsZ599lnXf/rTn1xzGCJn8rEL5cQTT3TdpUsX1zy0yy10lDRyay/z0JSnI+NMOS46xjOId+rUKe9xOfuOh7hNFS1i11r79u3ztp2n22LXChdcYttzaCln+uYWMqv0MDxNcGEsLkzG0+7x73LMmDGuhw4d6prDDjdu3Oj6N7/5jesjnf6tJegJXAghIkUduBBCREoqXSg8XOLaxDy8Gj16tGvO0MwdFj399NOuecbrDz/80DVPFXXDDTe45mE3t4OjTbitSYFdDz169Mhax9msPB0Zu1Y++OAD11xPeufOna7ZjXHWWWe55ggkHkLnwjPDc+QKRyKwa4VdOXfeeafrkSMb5xFmG8+fP981Ry8B/z5LvWieQu4sdlsBQF1dnWvOqma7Xnrppa6HDBnimt2f/BvlgmWV/r3qCVwIISKl2Q7czPqY2TwzW2VmK83s+5nlXc1sjpmtzfzbpbljidpBdk0msmu6KMaFcgjAf4YQlppZJwBLzGwOgFsBzA0h3GdmEwFMBHB3+ZpaOnbs2OGa
h6s8vOah1+rVq11zYSsAWLlypWtOMOFIEnabjBgxwvXu3btd87Rtf//7313zUL7EVM2uHFFyzTXXZK3jt/4nnXSSa/5OuEYzJ9NwJA8n3zCFZqjnxCsA2LRpk2t2lfCwu0OHDq45cqHQ9jy85kiTEkedJO73yrBrjO+j0047zTW7PPmeALLdaYWii9hdt3fvXtezZ892zW6TJ5980nWlk+2afQIPIWwJISzN6L0AVgPoBWAMgKmZzaYCGFumNooyILsmE9k1XRzRS0wz6wtgMIBFAOpCCIff6G0FUFdgnwkAJrSijaLMyK7JRHZNPkV34GbWEcB0AD8IIXzAiRIhhGBmeceBIYRJACZljlETGQo8jOZhLSd88Davvvqqa67VDWS/pea31xzFwq4ZDvpntwnPyP3OO++45siFclANu/Iw9hvf+EbWOnZ98Pfz2GOPueZIAq7Jze4provCtly3bp3rF154wfXkyZOz2sGJXHxfsObhPE/Nx7XB6+vr87aDh9q5iTylIEm/V07Q4iS8a6+91jXX1uca/2wLIDvBi793Trx75ZVXXPNvnzVPs1jNmeuLikIxs7ZouBkeDSEcjpvbZmY9M+t7AlDsU2TIrslEdk0PxUShGIDJAFaHEO6nVTMBjM/o8QBmlL55olzIrslEdk0XxbhQLgZwC4B/mNnyzLIfA7gPwFNmdjuADQDG5d+99uCgfU4AYLiGxRlnnOGakzSA7EQNjkLh4RnPOM/lYR988EHXu3btKqrtJaRqdu3evbvrfv36Za1jlxG7t2688UbXY8eOdc1JG1wSdP369a65PChHs7BbpiWRIJw4xOerMon7vXI0Ev/eOIKJbc9RSlw3B8j+7XMU0qOPPuqaI814m1qk2Q48hPASACuwemSB5aLGkV2TieyaLpSJKYQQkZLKWiicdMFB+1xelN98X3XVVQWPxftw5MLzzz/vetq0aa75TXZaZ9hh91S7du2y1nEC1de//vW8+7N7iuuiPPPMM65/97vfuX755Zfz7ivigOsPcTINz3TFdYg2b96cdxsgO3KIk7fYtVbrbhNGT+BCCBEp6sCFECJSUulC4SQPrjXCQ3BOyuFIh48++ijrWDNmNEZjcWnZhQsXumbXShLLwx4pPNzlSBwAGDhwYN59uOYJz1LEdW24PgUnQ3ECjYgPdpNxGV52gbDLk6NT2F2XC7taqhAFVhL0BC6EEJGiDlwIISIllS4UHrZzDQyOUOAhGSf75LpAeHjH5WG5PoLcJtnwLDoPPfRQ1rpCJT450Ya/c3aPsJtFbpNkUoxd2eXJtWuAbBco1z9hN2dM6AlcCCEiRR24EEJESipdKLkTE4vK8vHHH+fVQrQULhNbqL4RkB25wjNtVbMkbGvQE7gQQkSKOnAhhIgUdeBCCBEpqfSBCyHSQW6d95deesk1T4tW7qkLy4WewIUQIlLUgQshRKTIhSKESA0cQhxr6CCjJ3AhhIgUdeBCCBEp6sCFECJS1IELIUSkNNuBm1l7M1tsZq+b2UozuzezvJ+ZLTKzdWb2pJm1a+5YonaQXZOJ7JouiolCOQDgihDCPjNrC+AlM/t/AH4I4L9DCE+Y2W8B3A7g/5axraK0yK7JJJV25Trh27dvd71ixYqs7XjqtCTU6W/2CTw0cHjiyLaZTwBwBYBpmeVTAYwtRwNFeZBdk4nsmi6K8oGbWRszWw5gO4A5ANYD2BNCOPzf3iYAvQrsO8HM6s2svgTtFSVEdk0msmuKCCEU/QHQGcA8ACMArKPlfQC8UcT+QZ/a+MiuyfzIron91Oez0RFFoYQQ9qDhhhgOoLOZHfah9waw+UiOJWoH2TWZyK7Jp5golO5m1jmjOwAYBWA1Gm6M6zObjQcwo0xtFGVAdk0msmu6sNxyi/+2gdk5aHjp0QYNHf5TIYT/bWanAngCQFcAywB8I4RwoPCRADPbAeBDADtL0PbY6Ibaue5TAIxEae26AbV1jZWilq5Zdi0dtXbN
p4QQuucubLYDLzVmVh9CGFrRk9YAabjuNFxjLmm45jRcYy6xXLMyMYUQIlLUgQshRKRUowOfVIVz1gJpuO40XGMuabjmNFxjLlFcc8V94EIIIUqDXChCCBEp6sCFECJSKtqBm9loM1uTKWk5sZLnrhRm1sfM5pnZqkw5z+9nlnc1szlmtjbzb5dqt7VUpMGuQPpsK7vWvl0r5gM3szYA3kRDZtgmAK8BuDGEsKoiDagQZtYTQM8QwlIz6wRgCRoqv90KYFcI4b7Mj6FLCOHu6rW0NKTFrkC6bCu7xmHXSj6BD0NDQZ23QggH0ZAVNqaC568IIYQtIYSlGb0XDWnMvdBwrVMzmyWpnGcq7AqkzrayawR2rWQH3gvARvq7YEnLpGBmfQEMBrAIQF0IYUtm1VYAddVqV4lJnV2BVNhWdo3ArnqJWSbMrCOA6QB+EEL4gNeFBr+V4jcjRbZNJjHatZId+GY01CE+TGJLWmamspoO4NEQwtOZxdsyvrbDPrfthfaPjNTYFUiVbWXXCOxayQ78NQADMpOrtgNwA4CZFTx/RTAzAzAZwOoQwv20aiYayngCySrnmQq7AqmzrewagV0rmolpZlcBeAANpS4fCSH8V8VOXiHMbASAFwH8A8BnmcU/RoNP7SkAJ6OhROe4EMKuvAeJjDTYFUifbWXX2rerUumFECJS9BJTCCEiRR24EEJESqs68LSk2qYN2TW5yLYJI99U9cV80PBiYz2AUwG0A/A6gEHN7BP0qY2P7JrMTyl/s9W+Fn2yPjvy2ag1T+CpSbVNGbJrcpFt42VDvoWt6cCLSrU1swlmVm9m9a04l6gcsmtyada2smtcHF3uE4QQJiEzPZGZhXKfT1QG2TWZyK5x0Zon8FSl2qYI2TW5yLYJozUdeGpSbVOG7JpcZNuE0WIXSgjhkJn9B4DZaEy1XVmylomqILsmF9k2eVS6Fop8ajVCCMFKdSzZtXaQXRPLkhDC0NyFysQUQohIUQcuhBCRog5cCCEiRR24EEJEStkTeYQQopI0TLDTQNeuXbPWXXnlla53797tesWKFa63bt1axtaVFj2BCyFEpKgDF0KISJELRQiRKI4+urFbGzFiRNa6n/zkJ67nzp3r+t1333UtF4oQQoiyow5cCCEiRR24EEJEinzgQohEwT7wK664ImvdSSed5Jp93e+99175G1YG9AQuhBCRog5cCCEiRS4UkXg4M69NmzZlOQeXZf7ss8/yLheVgb//t99+O2vdgQMHXA8YMMB1XV2d6y1btpSvcSVGT+BCCBEp6sCFECJS5EIRiadbt26uBw8e7JpdK61l3759rv/1r3+55gw/URnYTXb66adnrWvfvr3rDh06uG7Xrl35G1YG9AQuhBCRog5cCCEiJREulLZt27q+7rrrXPft29f1wIEDXffr18/1CSeckPeYHD3AQ+3cqIJCw/DXX3/d9bx581wvXLjQ9Zo1a/LuKxrg7/b44493zdEDZ599tuvu3bu7PvXUU12z2+TEE0/Me/zW8sknn7jmKAa2/VtvveV6wYIFebfPPZZoHU1FHbF7a8+ePRVoTenRE7gQQkRKsx24mT1iZtvN7A1a1tXM5pjZ2sy/XcrbTFFqZNfkItumh2JcKFMA/A+AP9CyiQDmhhDuM7OJmb/vLn3zGunUqVPW3zx0/spXvuL6qquucs3D7s6dO+c9VqG3z611ofTp08f1+eef73rt2rWu58yZ4/qPf/yja45oKCNTUAN2zeWooxqfKXg6rLvvbmzGkCFDXPfo0cP1Mccc45ptzMfhIXUpXSh8X5x88smu2V23d+9e1yNHjnT98MMPZx1r2bJlrjnx5AiYghq0bTlhW7LtL7jggqztjjvuONfcJxx77LHla1wZafYJPISwAMCunMVjAEzN6KkAxpa2WaLcyK7JRbZNDy19iVkXQjj85mUrgLpCG5rZBAATWngeUVlk1+RSlG1l17hodRRKCCGYWcGCDyGESQAm
AUBT2+WDI0QuvfTSrHUTJjTeY+ecc47r3r17u961q/EhZMOGDa757fP69etdc92EYl0oPIS//vrrXffv3981D+E5MoYjInjY/Nprr7n+9NNPUQ3Kadem4KHs8OHDXfN3yzbm0qHVhO8RThA55ZRT8m7PZU1zXTmPPPKIa45aOnjwYKvbCTRt23LZtdyw641daezOArJdphx5sn///vI1roy0NAplm5n1BIDMv9tL1yRRRWTX5CLbJpCWduAzAYzP6PEAZpSmOaLKyK7JRbZNIM2OP83scQCXAehmZpsA3APgPgBPmdntADYAGFeqBnFSzllnneX6hhtuyNru6quvds2JD5w4wa6I1atXu968ebPrQi4Uhoe4uYk/48Y1XnoxEQP8FpxdAR07dsx7vnJRabsWgu0NZLuYbrnlFtfscqgVt0lrYLfaNddck7Vu27Ztrtndt27duqKOXSu2rSTsNjnttNNcN1XjhGvW7NixozwNKzPN/hJCCDcWWDWywHIRAbJrcpFt04MyMYUQIlJqbizKroRhw4a5vvzyy7O2++ijj1wvWbLE9S9/+UvXixcvdv3++++75hk7CsHDdJ6tY9SoUVnb3XHHHa75jfemTZtcc/JIz549XR86dMg1vxFP0ywu7FICgDPPPNM1u8lyXS1HAn+fHNVTbM0RnvCW7zsetrNrja+pGHcPRyMB2RFXS5cudV2sCyWNcPQSJ+80VQuFo3r4txgTegIXQohIUQcuhBCRUnMuFB5+9urVyzUnzADZiTm/+tWvXL/wwguuixki89CcayiwO4TdN3fddVfW/jx04yiW+fPnu2a3CSek8Btyrp3CpWirlchTKXKjerjsb2vqU3A9GXaBcHIXR3vkwpFAy5cvd82lXzmShEvWXnTRRa65xG2x8HcwaNCgI94/jbCrpEuXxjpdTUV08b3H9xrXrKl19AQuhBCRog5cCCEipeZcKBwxwG+Gc98Ss3uEh8jsluDhU6HIE3aVcNTDTTfd5Hro0KGuOQoByHbZPPTQQ6537tzp+lvf+lbednDEzSWXXOL62WefdZ302VnYdQQAY8eObfGx+N7h6I3Jkye75sikN99803VTkUnsZuOhOkeYXHjhha55CN8SFwq3JekutFLBv3Wui9KUC4XdYZqRRwghREVRBy6EEJFScy4Urv1QX1/vmmuZANkJH3/+859d8+wmPETmugdc7pNrmbDbhOHJaH/xi19krZs1a5br7dsbC7zxW22ut7J7927Xlah5UuvkJvJwzZMjhWvRrFixwvXs2bNds0uCa9Fw4lVuu3j2HE4u43K3POsS318tge831qIw7DrlEr7sTkkiyb46IYRIMOrAhRAiUtSBCyFEpNScD5xDwRYsWOB6ypQpWdv9+te/ds0ZcTzVGof8ceEa9ouxv5J95tOmTXP9/PPP590GKFwkq1AmIGd5ceZnWskt9tQa/zH7QTkckafc4/cOXIwqN4yM17FfnmcyZz95KWc1Zz+9wgiLg79/zoTNLWbF4cj83cZaQE5P4EIIESnqwIUQIlJqzoXCcDYjuzQAYOvWra7PPfdc18W4JbjoFIcX8jG5WBYvLzYzkt0pPGwrphZ5msgtHMS24Sn1ioFdY1xArFA2JG+fm+nL6zgTsxJhaeyW++CDD8p+viTArjjuA3JDdZctW+aaQ5ZjzXjWE7gQQkSKOnAhhIiUmnahsLth48aNWeumT5/ummefb9++fbPHZdcM60pPq8TDPh7mpylDc//+/Vl/87D2SF0oDEcfNDWt1mFaM2VbqWGXHRdcEq2Ha8B/+OGHrhMbhWJmfcxsnpmtMrOVZvb9zPKuZjbHzNZm/u3S3LFE7SC7JhPZNV0U40I5BOA/QwiDAFwE4H+Z2SAAEwHMDSEMADA387eIB9k1mciuKaJZF0oIYQuALRm918xWA+gFYAyAyzKbTQUwH8DdZWllHjgxp9Zn62Y3DUe3cDElLs5VzJC/tdSKXXPf/sc0nVWp
eOedd7L+5in1cotsNUet2FVUhiPygZtZXwCDASwCUJe5WQBgK4C6AvtMADAh3zpRG8iuyUR2TT5FR6GYWUcA0wH8IISQFZwaGt4A5H0LEEKYFEIYGkIYmm+9qC6yazKRXdNBUU/gZtYWDTfDoyGEpzOLt5lZzxDCFjPrCWB74SOkG44kWLNmjWue7b5v376uK1XDuBbsmusy4e+Ha8jwNGXVrPHM0QrFTH3G2+zatcs1R0PMnTs3a5958+a53rFjxxG3sRbsWivkRnRxLRuOWOPtYopIKSYKxQBMBrA6hHA/rZoJYHxGjwcwo/TNE+VCdk0msmu6KOYJ/GIAtwD4h5ktzyz7MYD7ADxlZrcD2ABgXP7dRY0iuyYT2TVFFBOF8hKAQpklIwssF0dIpYdttWLX3ESVGTMaHwx5xnqe9T23BG0lYTux+6fQrObsWlm+fLnrl19+2fXChQuz9mH3ypHeF7Vi10rDSXhsC3a9AcDAgQNd9+rVyzWXD/7444/L0MLyoFR6IYSIFHXgQggRKTVdCyXp8Jtv1rVUl6Pc8KxJAFBfX+/6u9/9rushQ4a4rub3wy4NTsrKnalJVBYuuzt79mzXt912W9Z2J554ouvTTz/ddbdu3VwfafJUNdETuBBCRIo6cCGEiBS5UKoID8c5suLKK690PWvWLNdpmJ2FE1/4el988cVqNKdJNLtS7VDovmmKM844w3WPHj1cy4UihBCi7KgDF0KISJELpYoUikLhN+LVTFqpNuxiqvRsSSIuuCwxz2jUlJuLXS0xJe8wegIXQohIUQcuhBCRkt7xeQ1QKAqF635wbRAuRyqEaITr0syfP9917mxHxx13nGueDD130vRY0BO4EEJEijpwIYSIFHXgQggRKfKBV4D333/fNfvauJATF2jq37+/a65TLITIz8GDB11zYbF77703azv+nS1YsMB1rFnOegIXQohIUQcuhBCRIhdKBeDZ1VeuXOl63bp1rgcNGuSaZ87msKfc2dhVTEmIf2f//v2uH3/88Sq2pPzoCVwIISJFHbgQQkSKXCgVgId0PA3X2rVrXZ955pmuuZgV69yIlNzpyIQQ6aLZJ3Aza29mi83sdTNbaWb3Zpb3M7NFZrbOzJ40s3blb64oFbJrMpFd00UxLpQDAK4IIZwL4DwAo83sIgA/B/DfIYT+AHYDuL1srRTlQHZNJrJrimi2Aw8N7Mv82TbzCQCuADAts3wqgLHlaGDSOHTokH/27dvnn6OOOso/3bp180/37t39065du6xPa5Bdk4nsmi6KeolpZm3MbDmA7QDmAFgPYE8I4XCV/U0AehXYd4KZ1ZtZfQnaK0qI7JpMZNf0UFQHHkL4NIRwHoDeAIYB+HyxJwghTAohDA0hDG1ZE0W5kF2TieyaHo4oCiWEsMfM5gEYDqCzmR2d+V+9N4DN5Whg0tixY4frRYsWub7llluq0RwAsmtSkV2TTzFRKN3NrHNGdwAwCsBqAPMAXJ/ZbDyAGXkPIGoS2TWZyK7popgn8J4ApppZGzR0+E+FEP5mZqsAPGFm/wfAMgCTy9hOUXpk12Qiu6YI42m9yn4ysx0APgSws2InrR26oXau+5QQQvdSHSxj1w2orWusFLV0zbJr6ai1a85r24p24ABgZvVpfEGShutOwzXmkoZrTsM15hLLNasWihBCRIo6cCGEiJRqdOCTqnDOWiAN152Ga8wlDdechmvMJYprrrgPXAghRGmQC0UIISJFHbgQQkRKRTtwMxttZmsyNYknVvLclcLM+pjZPDNblanH/P3M8q5mNsfM1mb+7VLttpaKNNgVSJ9tZdfat2vFfOCZzLA30ZDauwnAawBuDCGsqkgDKoSZ9QTQM4Sw1Mw6AViChtKdtwLYFUK4L/Nj6BJCuLt6LS0NabErkC7byq5x2LWST+DDAKwLIbwVQjgI4AkAYyp4/ooQQtgSQlia0XvRUIeiFxqudWpmsyTVY06FXYHU2VZ2jcCulezA
ewHYSH8XrEmcFMysL4DBABYBqAshbMms2gqgrlrtKjGpsyuQCtvKrhHYVS8xy4SZdQQwHcAPQggf8LrQ4LdS/GakyLbJJEa7VrID3wygD/2d2JrEZtYWDTfCoyGEpzOLt2V8bYd9btur1b4Skxq7AqmyrewagV0r2YG/BmCANcyO3Q7ADQBmVvD8FcHMDA2lOleHEO6nVTPRUIcZSFY95lTYFUidbWXXCOxa6XKyVwF4AEAbAI+EEP6rYievEGY2AsCLAP4B4LPM4h+jwaf2FICT0VCic1wIYVdVGlli0mBXIH22lV1r365KpRdCiEjRS0whhIgUdeBCCBEp6sCFECJS1IELIUSkqAMXQohIUQcuhBCRog5cCCEi5f8DfzEb/8l2OBcAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 6 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted: \"[1 1 1 1 1 1]\", Actual: \"[6 2 7 9 0 1]\"\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from mindspore import Tensor\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Take one batch of 6 images from the held-out TEST split, so the predictions\n",
    "# reflect generalisation rather than images the model was trained on.\n",
    "mnist = Mnist(\"./mnist\", split=\"test\", batch_size=6, resize=32)\n",
    "dataset_infer = mnist.run()\n",
    "ds_test = dataset_infer.create_dict_iterator()\n",
    "data = next(ds_test)\n",
    "images = data[\"image\"].asnumpy()\n",
    "labels = data[\"label\"].asnumpy()\n",
    "\n",
    "# model.predict returns one logit per class for each image; argmax picks the predicted class.\n",
    "output = model.predict(Tensor(data['image']))\n",
    "predicted = np.argmax(output.asnumpy(), axis=1)\n",
    "\n",
    "# Show each image (channel 0) titled with its predicted and actual class.\n",
    "fig, axes = plt.subplots(2, 3)\n",
    "for ax, image, pred, label in zip(axes.flat, images, predicted, labels):\n",
    "    ax.imshow(image[0], interpolation=\"none\", cmap=\"gray\")\n",
    "    ax.set_title(f\"pred {pred} / true {label}\")\n",
    "    ax.axis(\"off\")\n",
    "fig.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Print predicted vs. actual classes for the whole batch.\n",
    "print(f'Predicted: \"{predicted}\", Actual: \"{labels}\"')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
